# -*- coding: utf-8 -*-
"""
robustness_panels.py — Robustness checks for MPA misfit panels
================================================================
This standalone script generates TWO additional robustness datasets:
  1) Full-sample re-estimation (per year) — "alaam_nodes_panel__fullsample.csv" (+ 2015 labels)
  2) Survivor-2015 time-tracking — using 2015 survivor set across years
     "alaam_nodes_panel__survivors2015.csv" (+ 2015 labels)
It reads the *outputs generated by your main pipeline* (nodes & matrices)
from ./outputs_alaam/ by default, so it does **not** modify the main script.

I/O (defaults)
--------------
Base input dir (from main pipeline):   ./outputs_alaam/
  - nodes/misfit_dual_diff_{year}.csv
  - matrices/inst_Astar_{year}_public.csv

Outputs (written to --out-dir, which defaults to the base dir):
  - alaam/y2015_labels__fullsample.csv
  - alaam/alaam_nodes_panel__fullsample.csv
  - alaam/y2015_labels__survivors2015.csv
  - alaam/alaam_nodes_panel__survivors2015.csv
  - alaam/misfit_nodes_all_years__survivors2015.csv
  - summary/survivors2015_counts.csv

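Parameters
----------
--base-dir         Path to the base outputs folder (default: ./outputs_alaam)
--out-dir          Folder the robustness outputs are written to (default: same as --base-dir)
--high-pctl        Per-sea-region quantile of M used as the critical threshold (default: 0.85, i.e., Q85)
--sc-gate-q        Per-sea-region quantile of SC_deg used as the SC-degree gate (default: 0.40, i.e., Q40)
--abs-baseline     Lower floor applied to the M threshold (default: 0.10)
--strict-intersect Restrict the 2015 survivor label set to strict survivors (2015∩2020∩2025)
                   (default: True; pass --no-strict-intersect to keep every 2015 node)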
"""

import argparse
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import networkx as nx

# Snapshot years produced by the main pipeline, in chronological order.
# The preceding snapshot of a year t is looked up as YEARS[YEARS.index(t) - 1],
# so the ordering matters.
YEARS = [2015, 2020, 2025]
# Years that get a panel row; each is paired with the snapshot right before it.
PANEL_TARGET_YEARS = YEARS[1:]

def triangles_per_node(A: np.ndarray) -> np.ndarray:
    """Return how many triangles each node belongs to.

    Nodes ``i`` and ``j`` are treated as linked when ``A[i, j] > 0``. For each
    node the number of shared neighbours with every one of its neighbours is
    summed; a triangle is seen twice from a given corner (once through each of
    the other two corners), so the sum is halved.

    Parameters
    ----------
    A : np.ndarray
        Square adjacency matrix. Callers in this script pass the symmetric,
        zero-diagonal output of ``load_A_prev``.

    Returns
    -------
    np.ndarray
        Integer array of length ``A.shape[0]``.
    """
    neighbours = [set(np.flatnonzero(row > 0)) for row in A]
    per_node = [
        sum(len(own & neighbours[v]) for v in own) // 2
        for own in neighbours
    ]
    return np.asarray(per_node, dtype=int)

def zscore(s: pd.Series) -> pd.Series:
    """Standardise *s* to zero mean and unit population standard deviation.

    The standard deviation uses ``ddof=0``. When it is zero or undefined
    (constant, single-value or empty input) the series is only mean-centred,
    so no division by zero can produce inf/NaN.
    """
    centred = s - s.mean()
    spread = s.std(ddof=0)
    degenerate = spread == 0 or np.isnan(spread)
    return centred if degenerate else centred / spread

def load_nodes(base_dir: Path, year: int) -> pd.DataFrame:
    """Load one year's node table written by the main pipeline.

    Reads ``<base_dir>/nodes/misfit_dual_diff_{year}.csv`` and converts
    ``MPA_ID`` to ``str`` so it can be matched against matrix labels.

    Parameters
    ----------
    base_dir : Path
        Base outputs folder of the main pipeline (a ``str`` is also accepted).
    year : int
        Snapshot year.

    Returns
    -------
    pd.DataFrame
        The node table with ``MPA_ID`` as strings.

    Raises
    ------
    FileNotFoundError
        If the CSV does not exist (raised by ``pd.read_csv``).
    ValueError
        If any required column (``MPA_ID``, ``M``, ``SC_deg``, ``SeaRegion``,
        ``province``) is missing; all missing columns are reported together.
    """
    p = Path(base_dir) / "nodes" / f"misfit_dual_diff_{year}.csv"
    df = pd.read_csv(p)
    required = ['MPA_ID', 'M', 'SC_deg', 'SeaRegion', 'province']
    missing = [c for c in required if c not in df.columns]
    if missing:
        # A missing column is a schema error, not a missing file, so every
        # absent column (including MPA_ID) is reported as a ValueError.
        raise ValueError(f"Missing required column(s) {missing} in nodes file: {p}")
    df['MPA_ID'] = df['MPA_ID'].astype(str)
    return df

def load_A_prev(base_dir: Path, year: int) -> "tuple[np.ndarray, list]":
    """Load the institutional strong-edge matrix (A*) for *year* as a binary graph.

    Reads ``<base_dir>/matrices/inst_Astar_{year}_public.csv`` (first column is
    the row index). The matrix is binarised (``> 0``), symmetrised and given
    a zero diagonal.

    Returns
    -------
    (A, ids) : tuple[np.ndarray, list[str]]
        ``A`` is an ``n x n`` 0/1 int array whose rows and columns both follow
        the order of ``ids`` (the matrix row labels converted to ``str``).

    Raises
    ------
    ValueError
        If the matrix is not square.
    """
    p = Path(base_dir) / "matrices" / f"inst_Astar_{year}_public.csv"
    M = pd.read_csv(p, index_col=0)
    if M.shape[0] != M.shape[1]:
        raise ValueError(f"A* matrix is not square ({M.shape[0]}x{M.shape[1]}): {p}")
    ids = list(M.index.astype(str))
    cols = list(M.columns.astype(str))
    # Symmetrising with A.T only makes sense if column j is the same node as
    # row j. When the header carries the same (unique) labels in another order,
    # re-align the columns to the row order; otherwise keep positional order.
    if cols != ids and len(set(ids)) == len(ids) and set(cols) == set(ids):
        M.columns = cols
        M = M[ids]
    A = (M.to_numpy(dtype=float) > 0).astype(int)
    # enforce symmetric & zero diagonal
    A = np.maximum(A, A.T)
    np.fill_diagonal(A, 0)
    return A, ids

def compute_labels_for_year(df_year: pd.DataFrame, high_pctl: float, sc_gate_q: float, abs_baseline: float) -> pd.DataFrame:
    """Add the SC gate and the critical label ``Y_t`` to one year's node table.

    Thresholds are computed per ``SeaRegion`` over the rows passed in (i.e. the
    full sample of that year):

    * ``SC_gate`` = 1 when ``SC_deg`` >= the region's ``sc_gate_q`` quantile of
      ``SC_deg``. Rows whose region has no computed quantile (e.g. a missing
      ``SeaRegion``) are compared against their own ``SC_deg``.
    * ``Y_t`` = 1 when ``M`` > ``max(abs_baseline, tau)`` AND ``SC_gate`` is 1,
      where ``tau`` is the region's ``high_pctl`` quantile of ``M`` (0.0 for
      rows whose region has no computed quantile).

    Parameters
    ----------
    df_year : pd.DataFrame
        Needs columns ``SeaRegion``, ``SC_deg`` and ``M``; it is not modified.
    high_pctl, sc_gate_q : float
        Quantile levels for the M threshold and the SC gate.
    abs_baseline : float
        Lower floor for the M threshold.

    Returns
    -------
    pd.DataFrame
        A copy of *df_year* with integer columns ``SC_gate`` and ``Y_t``.
    """
    out = df_year.copy()
    if len(out) == 0:
        # Nothing to threshold; still return the expected (empty) columns.
        out['SC_gate'] = np.zeros(0, dtype=int)
        out['Y_t'] = np.zeros(0, dtype=int)
        return out

    sea = out['SeaRegion']
    by_sea = out.groupby('SeaRegion')

    # SC gate per sea-region on SC_deg (vectorised lookup instead of a row-wise
    # apply). `isin` separates "region has a quantile" from "no group at all".
    sc_q = by_sea['SC_deg'].quantile(sc_gate_q)
    gate = sea.map(sc_q).where(sea.isin(sc_q.index), out['SC_deg'])
    out['SC_gate'] = (out['SC_deg'] >= gate).astype(int)

    # Critical threshold per sea-region on M, floored by abs_baseline.
    m_q = by_sea['M'].quantile(high_pctl)
    tau = sea.map(m_q).where(sea.isin(m_q.index), 0.0)
    # Mirrors Python's max(abs_baseline, tau): a NaN tau falls back to the floor.
    tau = tau.where(tau > abs_baseline, abs_baseline)
    out['Y_t'] = ((out['M'] > tau) & (out['SC_gate'] == 1)).astype(int)
    return out

def build_struct_features_from_prev(A_prev: np.ndarray, ids_prev: list) -> dict:
    """Compute node-level structural covariates from the previous year's A* network.

    Parameters
    ----------
    A_prev : np.ndarray
        Square adjacency matrix; row/column ``i`` corresponds to ``ids_prev[i]``.
    ids_prev : list
        Node labels in matrix order.

    Returns
    -------
    dict[str, pd.Series]
        Float series indexed by ``ids_prev``:

        * ``logdeg_t1`` -- ``log1p`` of the row-sum degree;
        * ``triangles_t1`` -- triangles per node (``triangles_per_node``);
        * ``two_path_open_t1`` -- row sums of ``A @ A`` restricted to pairs that
          are neither adjacent nor identical (open two-paths);
        * ``corridor_betweenness_t1`` -- normalized betweenness centrality;
        * ``corridor_betweenness_t1_z`` -- its z-score (``zscore``).
    """
    A_prev = np.asarray(A_prev)
    n_prev = A_prev.shape[0]
    ids_prev = list(ids_prev)

    deg_prev = pd.Series(A_prev.sum(axis=1), index=ids_prev, dtype=float)
    logdeg_prev = np.log1p(deg_prev)

    tri_prev = pd.Series(triangles_per_node(A_prev), index=ids_prev, dtype=float)

    # Float matmul uses BLAS (integer matmul does not); counts stay exact in float64.
    A_f = A_prev.astype(float)
    open_mask = (1.0 - A_f) * (1.0 - np.eye(n_prev))
    two_open_prev = pd.Series((A_f @ A_f * open_mask).sum(axis=1), index=ids_prev, dtype=float)

    G_prev = nx.from_numpy_array(A_prev)
    bw_dict = nx.betweenness_centrality(G_prev, normalized=True)
    # from_numpy_array labels nodes 0..n-1; look scores up by that index rather
    # than relying on the result dict's iteration order.
    bw_prev = pd.Series([bw_dict[i] for i in range(n_prev)], index=ids_prev, dtype=float)
    bw_prev_z = zscore(bw_prev)

    return {
        'logdeg_t1': logdeg_prev,
        'triangles_t1': tri_prev,
        'two_path_open_t1': two_open_prev,
        'corridor_betweenness_t1': bw_prev,
        'corridor_betweenness_t1_z': bw_prev_z,
    }

def high_misfit_prev_indicator(df_prev: pd.DataFrame, *, q: float = 0.75, floor: float = 0.60) -> pd.Series:
    """Flag nodes whose previous-year misfit ``M`` is high.

    The cut-off is ``tau_high = max(quantile_q(M[M > 0]), floor)``, using 0.0
    for the quantile when no ``M`` is strictly positive. With the defaults this
    is ``max(Q3 of positive M, 0.60)``.

    Parameters
    ----------
    df_prev : pd.DataFrame
        Needs columns ``MPA_ID`` and ``M``.
    q : float, keyword-only
        Quantile level applied to the positive ``M`` values (default 0.75).
    floor : float, keyword-only
        Minimum cut-off (default 0.60).

    Returns
    -------
    pd.Series
        0/1 ints indexed by ``MPA_ID``; 1 where ``M > tau_high`` (NaN ``M`` gives 0).
    """
    m = df_prev['M']
    positive = m[m > 0]
    q_pos = float(positive.quantile(q)) if len(positive) > 0 else 0.0
    tau_high = max(q_pos, floor)
    return (df_prev.set_index('MPA_ID')['M'] > tau_high).astype(int)

def write_fullsample_panels(base_dir: Path, out_dir: Path, high_pctl: float, sc_gate_q: float, abs_baseline: float):
    """Write the full-sample robustness outputs under ``out_dir / "alaam"``.

    Every year is labelled with thresholds computed on that year's full node
    table (``compute_labels_for_year``). Two CSVs are written:

    * ``y2015_labels__fullsample.csv`` -- 2015 labels (``Y_2015``) for all nodes;
    * ``alaam_nodes_panel__fullsample.csv`` -- one row per node and target year
      in ``PANEL_TARGET_YEARS``, holding ``Y_t`` plus covariates taken from the
      preceding snapshot year (nodes absent from that snapshot get 0).
    """
    node_tables = {year: load_nodes(base_dir, year) for year in YEARS}
    networks = {year: load_A_prev(base_dir, year) for year in YEARS}  # year -> (A, ids)
    alaam_dir = out_dir / "alaam"

    # 2015 labels with full-sample thresholds.
    labels_2015 = compute_labels_for_year(node_tables[2015], high_pctl, sc_gate_q, abs_baseline)
    label_cols = ['MPA_ID', 'SeaRegion', 'province', 'M', 'SC_deg', 'SC_gate', 'Y_t']
    labels_2015[label_cols].rename(columns={'Y_t': 'Y_2015'}).to_csv(
        alaam_dir / "y2015_labels__fullsample.csv", index=False, encoding="utf-8-sig"
    )

    struct_cols = ['logdeg_t1', 'triangles_t1', 'two_path_open_t1',
                   'corridor_betweenness_t1', 'corridor_betweenness_t1_z']
    frames = []
    for year in PANEL_TARGET_YEARS:
        prev_year = YEARS[YEARS.index(year) - 1]
        labelled = compute_labels_for_year(node_tables[year], high_pctl, sc_gate_q, abs_baseline)
        features = build_struct_features_from_prev(*networks[prev_year])
        high_flags = high_misfit_prev_indicator(node_tables[prev_year])

        ids = labelled['MPA_ID']
        columns = {
            'Year': year,
            'MPA_ID': ids,
            'Y_t': labelled['Y_t'].astype(int),
            'SeaRegion': labelled['SeaRegion'],
            'province': labelled['province'],
            'M': labelled['M'],
            'SC_gate': labelled['SC_gate'].astype(int),
            'high_misfit_t1': ids.map(high_flags).fillna(0).astype(int),
        }
        # Lagged structural covariates; missing nodes default to 0.0.
        columns.update({name: ids.map(features[name]).fillna(0.0) for name in struct_cols})
        frames.append(pd.DataFrame(columns))

    pd.concat(frames, ignore_index=True).to_csv(
        alaam_dir / "alaam_nodes_panel__fullsample.csv", index=False, encoding="utf-8-sig"
    )
    print("[FULLSAMPLE] wrote alaam_nodes_panel__fullsample.csv & y2015_labels__fullsample.csv")

def write_survivors2015_panels(base_dir: Path, out_dir: Path, high_pctl: float, sc_gate_q: float,
                               abs_baseline: float, strict_intersection: bool = True):
    """Write the survivor-2015 robustness outputs.

    Survivor sets come from the row labels of the A* matrices:

    * ``S15`` -- 2015 nodes that are also present in 2020 and 2025 when
      *strict_intersection* is true, otherwise every 2015 node;
    * ``S20`` -- 2015 nodes present in 2020; ``S25`` -- 2015 nodes present in
      2025 (both independent of *strict_intersection*).

    Labels always use full-sample thresholds of their year and are then
    filtered to the relevant survivor set. Files written under ``out_dir``:

    * ``alaam/y2015_labels__survivors2015.csv`` -- 2015 labels for ``S15``;
    * ``alaam/alaam_nodes_panel__survivors2015.csv`` -- 2020 rows in ``S20``
      and 2025 rows in ``S25``, with covariates from the preceding snapshot;
    * ``alaam/misfit_nodes_all_years__survivors2015.csv`` -- long form with
      2015 rows in ``S15``, 2020 rows in ``S20`` and 2025 rows in ``S25``;
    * ``summary/survivors2015_counts.csv`` -- set sizes and label counts.
    """
    node_tables = {year: load_nodes(base_dir, year) for year in YEARS}
    networks = {year: load_A_prev(base_dir, year) for year in YEARS}  # year -> (A, ids)
    present = {year: set(networks[year][1]) for year in YEARS}
    alaam_dir = out_dir / "alaam"

    # Survivor sets
    base15 = present[2015]
    S15 = sorted((base15 & present[2020] & present[2025]) if strict_intersection else base15)
    S20 = sorted(base15 & present[2020])
    S25 = sorted(base15 & present[2025])

    def labelled(year):
        # Full-sample thresholds for *year*.
        return compute_labels_for_year(node_tables[year], high_pctl, sc_gate_q, abs_baseline)

    # 2015 labels, filtered to S15.
    lab15 = labelled(2015)
    surv15 = lab15[lab15['MPA_ID'].isin(S15)].copy()
    label_cols = ['MPA_ID', 'SeaRegion', 'province', 'M', 'SC_deg', 'SC_gate', 'Y_t']
    surv15[label_cols].rename(columns={'Y_t': 'Y_2015'}).to_csv(
        alaam_dir / "y2015_labels__survivors2015.csv", index=False, encoding="utf-8-sig"
    )

    lab_by_year = {year: labelled(year) for year in (2020, 2025)}
    survivors_by_year = {2020: S20, 2025: S25}

    struct_cols = ['logdeg_t1', 'triangles_t1', 'two_path_open_t1',
                   'corridor_betweenness_t1', 'corridor_betweenness_t1_z']
    frames = []
    for year in (2020, 2025):
        prev_year = YEARS[YEARS.index(year) - 1]
        lab = lab_by_year[year]
        kept = lab[lab['MPA_ID'].isin(survivors_by_year[year])].copy()
        features = build_struct_features_from_prev(*networks[prev_year])
        high_flags = high_misfit_prev_indicator(node_tables[prev_year])

        ids = kept['MPA_ID']
        columns = {
            'Year': year,
            'MPA_ID': ids,
            'Y_t': kept['Y_t'].astype(int),
            'SeaRegion': kept['SeaRegion'],
            'province': kept['province'],
            'M': kept['M'],
            'SC_gate': kept['SC_gate'].astype(int),
            'high_misfit_t1': ids.map(high_flags).fillna(0).astype(int),
        }
        columns.update({name: ids.map(features[name]).fillna(0.0) for name in struct_cols})
        frames.append(pd.DataFrame(columns))

    panel = pd.concat(frames, ignore_index=True)
    panel.to_csv(alaam_dir / "alaam_nodes_panel__survivors2015.csv", index=False, encoding="utf-8-sig")

    # Long-form descriptive table: S15 in 2015, S20 in 2020, S25 in 2025.
    base_cols = ['MPA_ID', 'SeaRegion', 'province', 'M', 'SC_deg']
    long_parts = [surv15[base_cols].assign(Year=2015)]
    for year in (2020, 2025):
        lab = lab_by_year[year]
        long_parts.append(lab[lab['MPA_ID'].isin(survivors_by_year[year])][base_cols].assign(Year=year))
    long_form = pd.concat(long_parts, ignore_index=True)[['Year'] + base_cols]
    long_form.to_csv(alaam_dir / "misfit_nodes_all_years__survivors2015.csv", index=False, encoding="utf-8-sig")

    # Summary counts
    panel_years = panel['Year']
    summary = pd.DataFrame([
        {'set': 'S15', 'n': len(S15)},
        {'set': 'S20', 'n': len(S20), 'Y_2020': int(panel.loc[panel_years == 2020, 'Y_t'].sum())},
        {'set': 'S25', 'n': len(S25), 'Y_2025': int(panel.loc[panel_years == 2025, 'Y_t'].sum())},
    ])
    summary.to_csv(out_dir / "summary" / "survivors2015_counts.csv", index=False, encoding="utf-8-sig")
    print("[SURVIVORS-2015] wrote alaam_nodes_panel__survivors2015.csv / y2015_labels__survivors2015.csv / misfit_nodes_all_years__survivors2015.csv")

def main(argv=None):
    """Command-line entry point: parse options and write both robustness datasets.

    Parameters
    ----------
    argv : list[str] | None
        Arguments to parse; ``None`` (the default) makes argparse read
        ``sys.argv[1:]``.

    Invalid quantile options (``--high-pctl`` / ``--sc-gate-q`` outside
    [0, 1]) are rejected with a usage error (exit code 2) before any data is
    loaded.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--base-dir', type=str, default='./outputs_alaam', help='Base outputs dir of the main pipeline')
    ap.add_argument('--out-dir', type=str, default=None, help='Output dir for robustness panels (default: same as base-dir)')
    ap.add_argument('--high-pctl', type=float, default=0.85, help='Critical quantile per sea-region (e.g., 0.85)')
    ap.add_argument('--sc-gate-q', type=float, default=0.40, help='SC-degree gate quantile per sea-region (e.g., 0.40)')
    ap.add_argument('--abs-baseline', type=float, default=0.10, help='Minimum floor for M threshold')
    # store_true with default=True makes this flag a no-op kept for CLI
    # compatibility; --no-strict-intersect is the effective switch.
    ap.add_argument('--strict-intersect', action='store_true', default=True, help='Use strict survivors (2015∩2020∩2025)')
    ap.add_argument('--no-strict-intersect', dest='strict_intersect', action='store_false', help='Looser survivors (2015 base only)')
    args = ap.parse_args(argv)

    # Fail fast with a usage error instead of a pandas error after loading data.
    for flag, value in (('--high-pctl', args.high_pctl), ('--sc-gate-q', args.sc_gate_q)):
        if not 0.0 <= value <= 1.0:
            ap.error(f"{flag} must be within [0, 1], got {value}")

    # NOTE: silences every warning process-wide for the rest of the run.
    warnings.filterwarnings('ignore')

    base_dir = Path(args.base_dir).resolve()
    out_dir = Path(args.out_dir).resolve() if args.out_dir else base_dir

    # Ensure dirs
    for sub in ['alaam', 'summary']:
        (out_dir / sub).mkdir(parents=True, exist_ok=True)

    # 1) Full-sample re-estimation
    write_fullsample_panels(base_dir, out_dir, args.high_pctl, args.sc_gate_q, args.abs_baseline)

    # 2) Survivor-2015 time-tracking
    write_survivors2015_panels(base_dir, out_dir, args.high_pctl, args.sc_gate_q, args.abs_baseline, args.strict_intersect)

if __name__ == '__main__':
    main()

